昨天我們提到使用策略迭代的方式,讓我們可以實作計算狀態價值。並給出了一個 GridWorld 的範例,今天我們將要使用策略迭代,在這個情況下實作計算狀態價值。
昨天已經介紹過 GridWorld ,因此今天不多做介紹,直接看程式碼。首先,載入套件並設定 GridWorld 環境變數:
# packages
import numpy as np
import time, os
# environment setting
prob_action = np.full(4,0.25)
func_value = np.zeros(16)
func_reward = np.full(16,-1)
func_reward[0] = 0
func_reward[15] = 0
num_actions = 4
num_states = 16
T = np.load('./gridworld/T.npy')
上述的程式碼中,我們設定了
接著設定演算法中的參數,設定如下:
# parameters
delta = 0.1
gamma = 0.99
theta = 0.05
counter = 1
這裡我們設定了四個參數,分別是
最後是迭代的程式碼:
# iterativa policy evaluation
while delta > theta:
func_value_now = func_value.copy()
for state in range(1,15):
prob_next_state = prob_action*T[:, state, :]
future_reward = func_reward + func_value_now*gamma
func_value[state] = np.sum(np.matmul(np.transpose(prob_next_state), future_reward))
delta = np.max(np.abs(func_value - func_value_now))
os.system('cls' if os.name == 'nt' else 'clear')
print('='*60)
print('[Parameters]')
print('Gamma = ' + str(gamma))
print('Threshold = ' + str(theta) + '\n')
print('[Variables]')
print('No.' + str(counter) + ' iteration')
print('Delta = ' + str(delta) + '\n')
print('[State-Value]')
print(func_value.reshape(4,4))
print('='*60)
counter += 1
time.sleep(1)
這邊要注意一下,因為我使用的環境是 Ubuntu 18.04 LTS,第 8 行清除終端機的指令,在 Windows 中可能沒有 (我沒有確認)。
透過迭代,我們更新狀態的價值,在達成中止條件之前,會不斷更新。經過 50 次迭代後,達成中止條件,結果如下:
============================================================
[Parameters]
Gamma = 0.99
Thershold = 0.05
[Variables]
No.50 iteration
Delta = 0.04759090371885932
[State-Value]
[[ 0. -10.62647251 -15.48618269 -17.07138964]
[-10.62647251 -13.90097573 -15.50704743 -15.48618269]
[-15.48618269 -15.50704743 -13.90097573 -10.62647251]
[-17.07138964 -15.48618269 -10.62647251 0. ]]
============================================================
至此,我們實現計算狀態價值的方法,這裡提供完整程式碼。在實作計算方法後,有些東西可以做額外的測試,如: